from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from config.settings import DATA_DIR

def get_dataloader(split: str, model_type: str, batch_size: int = 32, shuffle: bool = True):
    """Build a ``DataLoader`` over the image folder for one model type and split.

    Images are read through ``datasets.ImageFolder`` rooted at
    ``DATA_DIR / model_type / split``. Each image goes through a fixed
    pipeline: ``Resize(256)``, ``CenterCrop(224)``, ``ToTensor()`` and a
    per-channel ``Normalize``.

    Args:
        split: Name of the directory beneath the model-type directory.
        model_type: Name of the directory beneath ``DATA_DIR``.
        batch_size: Forwarded to ``DataLoader`` as ``batch_size``.
        shuffle: Forwarded to ``DataLoader`` as ``shuffle``.

    Returns:
        A ``DataLoader`` wrapping the ``ImageFolder`` dataset.
    """
    # NOTE(review): these values look like the commonly used ImageNet
    # channel statistics -- confirm they suit every model_type.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=channel_mean, std=channel_std),
    ]
    preprocessing = transforms.Compose(steps)

    image_root = DATA_DIR / model_type / split
    image_folder = datasets.ImageFolder(image_root, transform=preprocessing)
    return DataLoader(image_folder, batch_size=batch_size, shuffle=shuffle)